#include <limits.h>
#include <sstream>


#include "socket.h"
#include "log.h"
#include "fd_manage.h"
#include "hook.h"
#include "iomanager.h"

namespace qtch{
// File-local logger: every diagnostic emitted by Socket/SSLSocket below goes
// to the logger named "system".
static Logger::ptr logger = QTCH_LOG_NAME("system");

// Build a (not yet opened) stream socket whose address family matches `address`.
// The descriptor itself is created lazily by bind()/connect().
Socket::ptr Socket::CreateTCP(Address::ptr address){
    const int family = address->getFamily();
    Socket::ptr sock(new Socket(family, Socket::Type::TCP, 0));
    return sock;
}

// Build a (not yet opened) datagram socket whose address family matches `address`.
Socket::ptr Socket::CreateUDP(Address::ptr address){
    const int family = address->getFamily();
    Socket::ptr sock(new Socket(family, Socket::Type::UDP, 0));
    return sock;
}

// IPv4 stream socket (descriptor created lazily).
Socket::ptr Socket::CreateTCPSocket(){
    Socket::ptr sock(new Socket(Socket::Family::IPv4, Socket::Type::TCP, 0));
    return sock;
}

// IPv4 datagram socket (descriptor created lazily).
Socket::ptr Socket::CreateUDPSocket(){
    Socket::ptr sock(new Socket(Socket::Family::IPv4, Socket::Type::UDP, 0));
    return sock;
}

// IPv6 stream socket (descriptor created lazily).
Socket::ptr Socket::CreateTCPSocket6(){
    Socket::ptr sock(new Socket(Socket::Family::IPv6, Socket::Type::TCP, 0));
    return sock;
}

// IPv6 datagram socket (descriptor created lazily).
Socket::ptr Socket::CreateUDPSocket6(){
    Socket::ptr sock(new Socket(Socket::Family::IPv6, Socket::Type::UDP, 0));
    return sock;
}

// Unix-domain stream socket (descriptor created lazily).
Socket::ptr Socket::CreateUnixTCPSocket(){
    Socket::ptr sock(new Socket(Socket::Family::UNIX, Socket::Type::TCP, 0));
    return sock;
}

// Unix-domain datagram socket (descriptor created lazily).
Socket::ptr Socket::CreateUnixUDPSocket(){
    Socket::ptr sock(new Socket(Socket::Family::UNIX, Socket::Type::UDP, 0));
    return sock;
}


// Record the socket parameters only; no descriptor exists yet (m_sock == -1).
// A descriptor is later created by newSock() or adopted by init().
Socket::Socket(int family, int type, int protocol)
    : m_sock(-1), m_family(family), m_type(type),
      m_protocol(protocol), m_isConnected(false) {
}

// Release the descriptor, if any. Inside the destructor the call resolves to
// Socket::close() regardless of the dynamic type, so qualify it explicitly.
Socket::~Socket(){
    Socket::close();
}

int64_t Socket::getSendTimeout(){
    FdCtx::ptr ctx = FdMgr::GetInstance()->get(m_sock);
    if(ctx){
        return ctx->getTimeout(SO_SNDTIMEO);
    }
    return -1;
}

// Set SO_SNDTIMEO from a millisecond value `v` (split into seconds + microseconds).
void Socket::setSendTimeout(int64_t v){
    timeval tv;
    tv.tv_sec = int(v / 1000);
    tv.tv_usec = int((v % 1000) * 1000);
    setOption(SOL_SOCKET, SO_SNDTIMEO, tv);
}

int64_t Socket::getRecvTimeout(){
    FdCtx::ptr ctx = FdMgr::GetInstance()->get(m_sock);
    if(ctx){
        return ctx->getTimeout(SO_RCVTIMEO);
    }
    return -1;
}

// Set SO_RCVTIMEO from a millisecond value `v` (split into seconds + microseconds).
void Socket::setRecvTimeout(int64_t v){
    timeval tv;
    tv.tv_sec = int(v / 1000);
    tv.tv_usec = int((v % 1000) * 1000);
    setOption(SOL_SOCKET, SO_RCVTIMEO, tv);
}

// Thin wrapper over getsockopt(2) on this socket's descriptor.
// @param level   protocol level (e.g. SOL_SOCKET)
// @param option  option name (e.g. SO_ERROR)
// @param result  buffer that receives the option value
// @param len     in: size of `result`; out: number of bytes written
// @return true on success; false on failure (logged at debug level with errno)
bool Socket::getOption(int level, int option, void* result, socklen_t* len){
    if(getsockopt(m_sock, level,option, result,len)){
        // Label the failure with the call that actually failed (was "setOption").
        QTCH_LOG_DEBUG(logger) << "getOption sock=" << m_sock
            << " level=" << level << " option=" << option
            << " errno=" << errno << " errstr=" << strerror(errno);
        return false;
    }
    return true;
}

// Thin wrapper over setsockopt(2) on this socket's descriptor.
// @return true when setsockopt returns 0; otherwise logs errno at debug level
//         and returns false.
bool Socket::setOption(int level, int option, const void* result, socklen_t len){
    const int rc = setsockopt(m_sock, level, option, result, len);
    if(rc == 0){
        return true;
    }
    QTCH_LOG_DEBUG(logger) << "setOption sock=" << m_sock
        << " level=" << level << " option=" << option
        << " errno=" << errno << " errstr=" << strerror(errno);
    return false;
}

// Accept one pending connection on this listening socket and wrap it in a new
// Socket with the same family/type/protocol.
// @return the connected Socket, or nullptr if accept(2) or init() fails.
Socket::ptr Socket::accept(){
    Socket::ptr sock(new Socket(m_family,m_type,m_protocol));
    int newsock = ::accept(m_sock,nullptr,nullptr);
    if(newsock == -1){
        QTCH_LOG_ERROR(logger) << "accept(" << m_sock
            << ") errno=" << errno << " errstr=" << strerror(errno);
        return nullptr;
    }
    if(!sock->init(newsock)){
        // init() only stores the descriptor on success, so after a failure no
        // Socket owns `newsock`; close it here to avoid leaking the fd.
        ::close(newsock);
        return nullptr;
    }
    return sock;
}

// Bind to `addr`, creating the descriptor first if there is none yet.
// Fails when the descriptor cannot be created, when the address family differs
// from the socket's family, or when bind(2) fails. On success the local
// address is looked up and cached.
bool Socket::bind(const Address::ptr addr){
    if(!isValid()){
        newSock();
    }
    if(!isValid()){
        return false;
    }

    const int addrFamily = addr->getFamily();
    if(addrFamily != m_family){
        QTCH_LOG_ERROR(logger) << "bind sock.famiy(" << m_family << ")"
            << " addr.family(" << addrFamily << ") not queal, addr=" << addr;
        return false;
    }

    const bool failed = ::bind(m_sock, addr->getAddr(), addr->getAddrLen()) != 0;
    if(failed){
        QTCH_LOG_ERROR(logger) << "bind error addr=" << addr
            << " errno=" << errno << " errstr=" << strerror(errno);
        return false;
    }
    getLocalAddress();
    return true;
}

// Connect to `addr`, creating the descriptor first if needed.
// timeout_ms == (uint64_t)-1 means a plain ::connect; any other value goes
// through connect_with_timeout. On failure the socket is closed and false is
// returned; on success the remote and local addresses are resolved and cached.
bool Socket::connect(const Address::ptr addr, uint64_t timeout_ms){
    m_remoteAddress = addr;
    if(!isValid()){
        newSock();
    }
    if(!isValid()){
        return false;
    }

    if(addr->getFamily() != m_family){
        QTCH_LOG_ERROR(logger) << "socket::connect sock.family(" << m_family
            << ") addr.family(" << addr->getFamily() << ") not equal, addr=" << addr;
        return false;
    }

    const bool untimed = (timeout_ms == static_cast<uint64_t>(-1));
    const int rc = untimed
        ? ::connect(m_sock, addr->getAddr(), addr->getAddrLen())
        : ::connect_with_timeout(m_sock, addr->getAddr(), addr->getAddrLen(), timeout_ms);
    if(rc){
        QTCH_LOG_ERROR(logger) << "sock=" << m_sock << " connect(" << addr
            << ") error errno=" << errno << " errstr=" << strerror(errno);
        close();
        return false;
    }

    m_isConnected = true;
    getRemoteAddress();
    getLocalAddress();
    return true;
}

// Re-run connect() against the previously recorded remote address, discarding
// the cached local address first. Fails if no remote address was ever set.
bool Socket::reconnect(uint64_t timeout_ms){
    if(m_remoteAddress){
        m_localAddress.reset();
        return connect(m_remoteAddress, timeout_ms);
    }
    QTCH_LOG_ERROR(logger) << "reconnect m_remoteAddress is null";
    return false;
}

// Put the (already created) descriptor into listening mode.
// Returns false, with an error log, if there is no descriptor or listen(2) fails.
bool Socket::listen(int backlog){
    if(isValid()){
        if(::listen(m_sock, backlog) == 0){
            return true;
        }
        QTCH_LOG_ERROR(logger) << "listen error errno=" << errno
            << " errstr="<< strerror(errno);
        return false;
    }
    QTCH_LOG_ERROR(logger) << "listen error sock=-1";
    return false;
}

// Mark the socket disconnected and close the descriptor if one is held.
// Idempotent; always returns true.
bool Socket::close(){
    m_isConnected = false;
    if(m_sock == -1){
        return true;
    }
    ::close(m_sock);
    m_sock = -1;
    return true;
}

// send(2) on a connected socket; -1 when not connected.
int Socket::send(const void* buffer, size_t length, int flags){
    if(!isConnected()){
        return -1;
    }
    return ::send(m_sock, buffer, length, flags);
}

// Scatter/gather send of `length` iovecs via sendmsg(2); -1 when not connected.
int Socket::send(const iovec* buffer, size_t length, int flags){
    if(!isConnected()){
        return -1;
    }
    msghdr hdr{};
    hdr.msg_iov = const_cast<iovec*>(buffer);
    hdr.msg_iovlen = length;
    return ::sendmsg(m_sock, &hdr, flags);
}

// sendto(2) toward `to`.
// sendto does not need a connected peer (the typical case is a bound, unconnected
// UDP socket), and m_isConnected is only set by connect()/init(). Requiring a
// connection made datagram sends impossible, so only a live descriptor is required.
// @return bytes sent, or -1 on error / when there is no descriptor.
int Socket::sendTo(const void* buffer, size_t length, const Address::ptr to, int flags){
    if(!isValid()){
        return -1;
    }
    return ::sendto(m_sock, buffer, length, flags, to->getAddr(), to->getAddrLen());
}

// Scatter/gather sendmsg(2) toward `to`.
// Like sendto, this needs a descriptor but not a connected peer (unconnected UDP),
// so gate on isValid() rather than isConnected().
// @return bytes sent, or -1 on error / when there is no descriptor.
int Socket::sendTo(const iovec* buffer, size_t length, const Address::ptr to, int flags){
    if(!isValid()){
        return -1;
    }
    msghdr msg;
    memset(&msg,0,sizeof(msg));
    msg.msg_iov = const_cast<iovec*>(buffer);
    msg.msg_iovlen = length;
    msg.msg_name = to->getAddr();
    msg.msg_namelen = to->getAddrLen();
    return ::sendmsg(m_sock,&msg,flags);
}

// recv(2) on a connected socket; -1 when not connected.
int Socket::recv(void* buffer, size_t length, int flag){
    if(!isConnected()){
        return -1;
    }
    return ::recv(m_sock, buffer, length, flag);
}

// Scatter receive into `length` iovecs via recvmsg(2); -1 when not connected.
int Socket::recv(iovec* buffer, size_t length, int flag){
    if(!isConnected()){
        return -1;
    }
    msghdr hdr{};
    hdr.msg_iov = buffer;
    hdr.msg_iovlen = length;
    return ::recvmsg(m_sock, &hdr, flag);
}

int Socket::recvFrom(void* buffer, size_t length, Address::ptr from, int flag){
    if(isConnected()){
        socklen_t len = from->getAddrLen();
        return ::recvfrom(m_sock,buffer,length,flag,from->getAddr(),&len);
    }
    return -1;
}

int Socket::recvFrom(iovec* buffer, size_t length, Address::ptr from, int flag){
    if(isConnected()){
        msghdr msg;
        memset(&msg,0,sizeof(msg));
        msg.msg_iov = (iovec*)buffer;
        msg.msg_iovlen = length;
        msg.msg_name = from->getAddr();
        msg.msg_namelen = from->getAddrLen();
        return ::recvmsg(m_sock,&msg,flag);
    }
    return -1;
}

// Write a one-line debug description of the socket into `os`. Remote and local
// addresses are printed only when they are cached. Returns `os`.
std::ostream& Socket::dump(std::ostream& os)const{
    os << "[Socket sock=" << m_sock;
    os << " is_connected=" << m_isConnected;
    os << " famiyl=" << m_family;
    os << " type=" << m_type;
    os << " protocol=" << m_protocol;
    if(m_remoteAddress) os << " remote_address=" << m_remoteAddress;
    if(m_localAddress)  os << " local_address=" << m_localAddress;
    return os << "]";
}

// dump() rendered into a std::string.
std::string Socket::toString() const{
    std::ostringstream out;
    dump(out);
    return out.str();
}

// Peer address of the socket, cached in m_remoteAddress after the first
// successful getpeername(2). On failure an UnknowAddress for the socket's
// family is returned and nothing is cached.
Address::ptr Socket::getRemoteAddress(){
    if(m_remoteAddress){
        return m_remoteAddress;
    }

    // Pick an empty address object of the right family for getpeername to fill.
    Address::ptr peer;
    if(m_family == AF_INET){
        peer.reset(new IPv4Address());
    }else if(m_family == AF_INET6){
        peer.reset(new IPv6Address());
    }else if(m_family == AF_UNIX){
        peer.reset(new UnixAddress());
    }else{
        peer.reset(new UnknowAddress(m_family));
    }

    socklen_t peerLen = peer->getAddrLen();
    if(getpeername(m_sock, peer->getAddr(), &peerLen) != 0){
        QTCH_LOG_DEBUG(logger) << "getpeername m_sock=" << m_sock
            << " errno=" << errno << " errstr=" << strerror(errno);
        return Address::ptr(new UnknowAddress(m_family));
    }
    QTCH_ASSERT(peer->getFamily() == m_family);
    // Unix socket paths are variable-length; record the length the kernel reported.
    if(m_family == AF_UNIX){
        std::dynamic_pointer_cast<UnixAddress>(peer)->setAddrLen(peerLen);
    }
    m_remoteAddress = peer;
    return m_remoteAddress;
}

// Local address of the socket, cached in m_localAddress after the first
// successful getsockname(2). On failure an UnknowAddress for the socket's
// family is returned and nothing is cached.
Address::ptr Socket::getLocalAddress(){
    if(m_localAddress){
        return m_localAddress;
    }

    // Pick an empty address object of the right family for getsockname to fill.
    Address::ptr self;
    if(m_family == AF_INET){
        self.reset(new IPv4Address());
    }else if(m_family == AF_INET6){
        self.reset(new IPv6Address());
    }else if(m_family == AF_UNIX){
        self.reset(new UnixAddress());
    }else{
        self.reset(new UnknowAddress(m_family));
    }

    socklen_t selfLen = self->getAddrLen();
    if(getsockname(m_sock, self->getAddr(), &selfLen) != 0){
        QTCH_LOG_DEBUG(logger) << "getsockname m_sock=" << m_sock
            << " errno=" << errno << " errstr=" << strerror(errno);
        return Address::ptr(new UnknowAddress(m_family));
    }
    QTCH_ASSERT(self->getFamily() == m_family);
    // Unix socket paths are variable-length; record the length the kernel reported.
    if(m_family == AF_UNIX){
        std::dynamic_pointer_cast<UnixAddress>(self)->setAddrLen(selfLen);
    }
    m_localAddress = self;
    return m_localAddress;
}

// True while a descriptor is held (m_sock is set to -1 when there is none).
bool Socket::isValid() const{
    return -1 != m_sock;
}

// Pending socket error (SO_ERROR). If the option itself cannot be read,
// the errno from that failure is returned instead.
int Socket::getError(){
    int pending = 0;
    socklen_t pendingLen = sizeof(pending);
    const bool fetched = getOption(SOL_SOCKET, SO_ERROR, &pending, &pendingLen);
    return fetched ? pending : errno;
}

// Cancel a pending READ event for this descriptor in the current IOManager.
// @return false if there is no descriptor or no current IOManager; otherwise
//         whatever cancelEvent reports.
bool Socket::cancelRead(){
    if(!isValid()){
        return false;
    }
    // NOTE(review): getThis() is assumed to be null when the caller is not running
    // inside an IOManager; guard instead of dereferencing it blindly.
    auto iom = IOManager::getThis();
    if(!iom){
        return false;
    }
    return iom->cancelEvent(m_sock, IOManager::Event::READ);
}

// Cancel a pending WRITE event for this descriptor in the current IOManager.
// @return false if there is no descriptor or no current IOManager; otherwise
//         whatever cancelEvent reports.
bool Socket::cancelWrite(){
    if(!isValid()){
        return false;
    }
    // NOTE(review): getThis() is assumed to be null outside an IOManager thread.
    auto iom = IOManager::getThis();
    if(!iom){
        return false;
    }
    return iom->cancelEvent(m_sock, IOManager::Event::WRITE);
}

// Cancel a pending accept; an incoming connection is signalled as a READ
// event on the listening descriptor, so this cancels READ.
// @return false if there is no descriptor or no current IOManager; otherwise
//         whatever cancelEvent reports.
bool Socket::cancelAccept(){
    if(!isValid()){
        return false;
    }
    // NOTE(review): getThis() is assumed to be null outside an IOManager thread.
    auto iom = IOManager::getThis();
    if(!iom){
        return false;
    }
    return iom->cancelEvent(m_sock, IOManager::Event::READ);
}

// Cancel every pending event for this descriptor in the current IOManager.
// @return false if there is no descriptor or no current IOManager; otherwise
//         whatever cancelEventAll reports.
bool Socket::cancelAll(){
    if(!isValid()){
        return false;
    }
    // NOTE(review): getThis() is assumed to be null outside an IOManager thread.
    auto iom = IOManager::getThis();
    if(!iom){
        return false;
    }
    return iom->cancelEventAll(m_sock);
}

// Default options for a freshly created descriptor: SO_REUSEADDR always, and
// TCP_NODELAY (Nagle disabled) for stream sockets. Failures are only logged
// by setOption.
void Socket::initSock(){
    int enable = 1;
    setOption(SOL_SOCKET, SO_REUSEADDR, enable);
    if(SOCK_STREAM == m_type){
        setOption(IPPROTO_TCP, TCP_NODELAY, enable);
    }
}

// Create the descriptor from the stored family/type/protocol and apply the
// default options. On failure m_sock stays -1 and the error is logged.
void Socket::newSock(){
    m_sock = ::socket(m_family, m_type, m_protocol);
    if(!isValid()){
        QTCH_LOG_ERROR(logger) << "socket("<< m_family
            << ", " << m_type << ", " << m_protocol
            << ") errno=" << errno << " errstr="<< strerror(errno);
        return;
    }
    initSock();
}

bool Socket::init(int sock){
    FdCtx::ptr ctx = FdMgr::GetInstance()->get(sock);
    if(ctx && ctx->isSocket() && !ctx->isClosed()){
        m_sock = sock;
        m_isConnected = true;
        
        // 设置初始的超时时间
        setRecvTimeout(5000);
        setSendTimeout(5000);
        
        getRemoteAddress();
        getLocalAddress();
        QTCH_LOG_DEBUG(logger) << "Socket::init( " << sock << " ) " << " sucess, " <<  toString();
        return true;
    }
    return false;
}

// Stream a socket via its virtual dump(), so subclasses print their own form.
// `sock` must be non-null.
std::ostream& operator<<(std::ostream& os, const Socket::ptr sock){
    return sock->dump(os);
}



// One-time OpenSSL library setup, run during static initialization of this
// translation unit through the file-local `s_init` instance defined below.
// Order matters: the library is initialized before error strings and
// algorithm tables are loaded.
// NOTE(review): on OpenSSL >= 1.1.0 these calls are compatibility macros and
// initialization is automatic; confirm which OpenSSL version is targeted.
struct _SSLInit {
    _SSLInit(){
        SSL_library_init();
        SSL_load_error_strings();
        OpenSSL_add_all_algorithms();
    }
};

static _SSLInit s_init;

// Same lazy construction as Socket; the TLS context/session are set up later
// by connect(), loadCertificates() or init().
SSLSocket::SSLSocket(int family, int type, int protocol)
    : Socket(family, type, protocol) {
}

// Binding happens below the TLS layer, so the plain socket logic applies unchanged.
bool SSLSocket::bind(const Address::ptr addr)  {
    const bool bound = Socket::bind(addr);
    return bound;
}

// TCP-connect to `addr` (see Socket::connect for timeout_ms), then run the
// client-side TLS handshake on a fresh SSL_CTX/SSL pair.
// Any TLS setup or handshake failure closes the connection and returns false,
// matching Socket::connect, which closes on failure. Previously a failed
// handshake left an open TCP connection plus a half-initialized m_ssl behind,
// and null SSL_CTX_new/SSL_new results were passed straight into OpenSSL.
bool SSLSocket::connect(const Address::ptr addr, uint64_t timeout_ms )  {
    if(!Socket::connect(addr, timeout_ms)){
        return false;
    }

    m_ctx.reset(SSL_CTX_new(SSLv23_client_method()),SSL_CTX_free);
    if(!m_ctx){
        QTCH_LOG_ERROR(logger) << "SSLSocket::connect SSL_CTX_new error";
        m_ssl.reset();
        close();
        return false;
    }

    m_ssl.reset(SSL_new(m_ctx.get()),SSL_free);
    if(!m_ssl){
        QTCH_LOG_ERROR(logger) << "SSLSocket::connect SSL_new error";
        close();
        return false;
    }

    if(SSL_set_fd(m_ssl.get(),m_sock) != 1){
        QTCH_LOG_ERROR(logger) << "SSLSocket::connect SSL_set_fd error, sock=" << m_sock;
        m_ssl.reset();
        close();
        return false;
    }

    const int rt = SSL_connect(m_ssl.get());
    if(rt != 1){
        QTCH_LOG_ERROR(logger) << "SSLSocket::connect SSL_connect error, ssl_error="
            << SSL_get_error(m_ssl.get(), rt);
        // Drop the session before the fd is released so nothing can reuse it.
        m_ssl.reset();
        close();
        return false;
    }
    return true;
}

// Listening is a plain-socket operation; TLS starts per connection in init().
bool SSLSocket::listen(int backlog)  {
    const bool listening = Socket::listen(backlog);
    return listening;
}

bool SSLSocket::close()  {
    return Socket::close();
}

int SSLSocket::send(const void* buffer, size_t length, int flags )  {
    if(m_ssl) {
        return SSL_write(m_ssl.get(),buffer,length);
    }
    return -1;
}

int SSLSocket::send(const iovec* buffer, size_t length, int flags )  {
    if(!m_ssl){
        return -1;
    }
    int total = 0;
    for(size_t i = 0; i < length; ++i){
        if(buffer[i].iov_len == 0){
            continue;
        }
        int tmp = SSL_write(m_ssl.get(),buffer[i].iov_base,buffer[i].iov_len);
        if(tmp <= 0){
            return tmp;
        }
        total += tmp;
        if(tmp != (int)buffer[i].iov_len){
            break;
        }
    }
    return total;
}

// Addressed sends are not supported on a TLS stream; this asserts and returns -1.
int SSLSocket::sendTo(const void* buffer, size_t length, const Address::ptr to, int flags )  {
    (void)buffer;
    (void)length;
    (void)to;
    (void)flags;
    QTCH_ASSERT(false);
    return -1;
}

// Addressed gather-sends are not supported on a TLS stream; this asserts and returns -1.
int SSLSocket::sendTo(const iovec* buffer, size_t length, const Address::ptr to, int flags )  {
    (void)buffer;
    (void)length;
    (void)to;
    (void)flags;
    QTCH_ASSERT(false);
    return -1;
}

int SSLSocket::recv(void* buffer, size_t length, int flag)  {
    if(m_ssl) {
        return SSL_read(m_ssl.get(),buffer,length);
    }
    return -1;
}

int SSLSocket::recv(iovec* buffer, size_t length, int flag)  {
    if(!m_ssl){
        return -1;
    }
    int total = 0;
    for(size_t i = 0; i < length; ++i){
        if(buffer[i].iov_len == 0){
            continue;
        }
        int tmp = SSL_read(m_ssl.get(),buffer[i].iov_base,buffer[i].iov_len);
        if(tmp <= 0){
            return tmp;
        }
        total += tmp;
        if(tmp != (int)buffer[i].iov_len){
            break;
        }
    }
    return total;
}

int SSLSocket::recvFrom(void* buffer, size_t length, Address::ptr from, int flag)  {
    QTCH_ASSERT(false);
    return -1;
}

int SSLSocket::recvFrom(iovec* buffer, size_t length, Address::ptr from, int flag)  {
    QTCH_ASSERT(false);
    return -1;
}

// One-line debug description, same fields as Socket::dump but tagged
// "SSLSocket". Cached addresses are printed only when present. Returns `os`.
std::ostream& SSLSocket::dump(std::ostream& os)const  {
    os << "[SSLSocket sock=" << m_sock;
    os << " is_connected=" << m_isConnected;
    os << " famiyl=" << m_family;
    os << " type=" << m_type;
    os << " protocol=" << m_protocol;
    if(m_remoteAddress) os << " remote_address=" << m_remoteAddress;
    if(m_localAddress)  os << " local_address=" << m_localAddress;
    return os << "]";
}

bool SSLSocket::init(int sock) {
    bool v = Socket::init(sock);
    if(v){
        m_ssl.reset(SSL_new(m_ctx.get()),SSL_free);
        int rt = SSL_set_fd(m_ssl.get(), sock);
        if(!rt){
            rt = SSL_get_error(m_ssl.get(),rt);
            QTCH_LOG_DEBUG(logger) << "SSL_set_fd error, rt=" << rt;
            return false;
        }
        rt = SSL_accept(m_ssl.get());
        v = ( rt== 1);
        if(!v){
            rt = SSL_get_error(m_ssl.get(),rt);
            QTCH_LOG_DEBUG(logger) << "error rt=" << rt;
        }
    }
    return v;
}

// Accept one pending connection and return it as an SSLSocket that shares this
// listener's SSL_CTX and has completed the TLS handshake.
// @return the new socket, or nullptr if accept(2), adoption, or the handshake fails.
Socket::ptr SSLSocket::accept(){
    SSLSocket::ptr sock(new SSLSocket(m_family,m_type,m_protocol));
    int newsock = ::accept(m_sock,nullptr,nullptr);
    if(newsock == -1){
        QTCH_LOG_ERROR(logger) << "accept(" << m_sock
            << ") errno=" << errno << " errstr=" << strerror(errno);
        return nullptr;
    }
    sock->m_ctx = m_ctx;
    if(!sock->init(newsock)){
        if(sock->m_sock != -1){
            // The descriptor was adopted and only the TLS part failed: close through the socket.
            sock->close();
        }else{
            // Socket::init rejected the fd and never took ownership; close it
            // directly, otherwise it leaks (sock->close() would be a no-op).
            ::close(newsock);
        }
        return nullptr;
    }
    return sock;
}


// TLS-capable stream socket whose family matches `address` (descriptor created lazily).
SSLSocket::ptr SSLSocket::CreateTCP(Address::ptr address) {
    return SSLSocket::ptr(new SSLSocket(address->getFamily(), TCP, 0));
}

// TLS-capable IPv4 stream socket (descriptor created lazily).
SSLSocket::ptr SSLSocket::CreateTCPSocket() {
    return SSLSocket::ptr(new SSLSocket(IPv4, TCP, 0));
}

// TLS-capable IPv6 stream socket (descriptor created lazily).
SSLSocket::ptr SSLSocket::CreateTCPSocket6() {
    return SSLSocket::ptr(new SSLSocket(IPv6, TCP, 0));
}

bool SSLSocket::loadCertificates(const std::string& cert_file, const std::string& key_file){
    m_ctx.reset(SSL_CTX_new(SSLv23_server_method()),SSL_CTX_free);
    QTCH_LOG_DEBUG(logger) << "loadCertificates cert_file=" << cert_file << " key_file=" << key_file;
    if(SSL_CTX_use_certificate_chain_file(m_ctx.get(),cert_file.c_str()) != 1){
        QTCH_LOG_ERROR(logger) << "SSL_CTX_use_certificate_chain_file(" << cert_file << ") error";
        return false;
    }
    if(SSL_CTX_use_RSAPrivateKey_file(m_ctx.get(),key_file.c_str(),SSL_FILETYPE_PEM) != 1){
        QTCH_LOG_ERROR(logger) << "SSL_CTX_use_PrivateKey_file(" << key_file << ") error";
        return false;
    }
    if(SSL_CTX_check_private_key(m_ctx.get()) != 1){
        QTCH_LOG_ERROR(logger) << "SSL_CTX_check_private_key error";
        return false;
    }
    return true;
}













}
